{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from io import open\n",
    "import glob\n",
    "import unicodedata\n",
    "import string\n",
    "import math\n",
    "import os\n",
    "import time\n",
    "import torch.nn as nn\n",
    "import torch\n",
    "import random\n",
    "import matplotlib.pyplot as plt\n",
    "import torch.utils.data\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Working directory of the notebook.  The original used\n",
    "# os.path.dirname(os.path.abspath(__name__)), which only *happens* to resolve\n",
    "# to the CWD (abspath of the string \"__main__\") — os.getcwd() says what it means.\n",
    "base_dir = os.getcwd()\n",
    "# Run on GPU when available, otherwise fall back to CPU.\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "# Character vocabulary: ASCII letters plus punctuation seen in the name files.\n",
    "all_letters = string.ascii_letters + \" .,;'\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Size of the one-hot character vocabulary (52 letters + \" .,;'\" = 57).\n",
    "n_letters = len(all_letters) "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def unicodeToAscii(s):\n",
    "    \"\"\"Strip accents from `s`, keeping only characters present in all_letters.\"\"\"\n",
    "    # NFD splits accented characters into base char + combining mark ('Mn').\n",
    "    decomposed = unicodedata.normalize('NFD', s)\n",
    "    kept = [ch for ch in decomposed\n",
    "            if unicodedata.category(ch) != 'Mn' and ch in all_letters]\n",
    "    return ''.join(kept)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def readLines(filename):\n",
    "    \"\"\"Read a UTF-8 text file and return its lines, accent-stripped via unicodeToAscii.\"\"\"\n",
    "    # Fix: use a context manager so the file handle is closed deterministically\n",
    "    # (the original open(...).read() left closing to the garbage collector).\n",
    "    with open(filename, encoding='utf-8') as f:\n",
    "        lines = f.read().strip().split('\\n')\n",
    "    return [unicodeToAscii(line) for line in lines]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): hardcoded absolute path — assumes the per-language name files\n",
    "# live under /mnt/data/name_data/names; adjust for other machines.\n",
    "file_dir = \"/mnt/data/name_data/names\"\n",
    "path_txt = os.path.join(file_dir,\"*.txt\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build {language: [names]} from every *.txt file; the file stem is the category.\n",
    "category_lines = {}\n",
    "all_categories = []\n",
    "for path in glob.glob(path_txt):\n",
    "    language = os.path.splitext(os.path.basename(path))[0]\n",
    "    all_categories.append(language)\n",
    "    category_lines[language] = readLines(path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['Russian',\n",
       " 'Polish',\n",
       " 'Czech',\n",
       " 'Chinese',\n",
       " 'Irish',\n",
       " 'Japanese',\n",
       " 'French',\n",
       " 'Arabic',\n",
       " 'Dutch',\n",
       " 'English',\n",
       " 'Italian',\n",
       " 'Vietnamese',\n",
       " 'Portuguese',\n",
       " 'Greek',\n",
       " 'Scottish',\n",
       " 'Korean',\n",
       " 'Spanish',\n",
       " 'German']"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "all_categories"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "def letterToIndex(letter):\n",
    "    \"\"\"Index of `letter` within all_letters.\n",
    "\n",
    "    NOTE(review): str.find returns -1 for characters outside the vocabulary,\n",
    "    which would silently one-hot the *last* position; inputs are presumably\n",
    "    pre-filtered by unicodeToAscii — confirm before reusing elsewhere.\n",
    "    \"\"\"\n",
    "    return all_letters.find(letter)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "def letterToTensor(letter):\n",
    "    \"\"\"One-hot encode a single character as a <1 x n_letters> tensor.\"\"\"\n",
    "    one_hot = torch.zeros(1, n_letters)\n",
    "    one_hot[0, letterToIndex(letter)] = 1\n",
    "    return one_hot"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "def lineToTensor(line):\n",
    "    \"\"\"One-hot encode a string as a <len(line) x 1 x n_letters> tensor.\"\"\"\n",
    "    encoded = torch.zeros(len(line), 1, n_letters)\n",
    "    for pos, ch in enumerate(line):\n",
    "        encoded[pos, 0, letterToIndex(ch)] = 1\n",
    "    return encoded"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "def categoryFromOutput(output):\n",
    "    \"\"\"Map network output (log-probabilities) to (category name, category index).\"\"\"\n",
    "    _, top_index = output.topk(1)\n",
    "    idx = top_index[0].item()\n",
    "    return all_categories[idx], idx"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "def randomChoice(l):\n",
    "    \"\"\"Uniformly pick one element of a non-empty sequence.\"\"\"\n",
    "    # random.choice is the idiomatic stdlib equivalent of l[randint(0, len(l)-1)].\n",
    "    return random.choice(l)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "def randomTrainingExample():\n",
    "    \"\"\"Sample a random (category, name) training pair plus its tensor encodings.\"\"\"\n",
    "    category = randomChoice(all_categories)          # pick a language\n",
    "    line = randomChoice(category_lines[category])    # pick one name from it\n",
    "    label_idx = all_categories.index(category)\n",
    "    category_tensor = torch.tensor([label_idx], dtype=torch.long)\n",
    "    return category, line, category_tensor, lineToTensor(line)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'Whelan'"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Scratch check: sample one (category, name) pair by hand.\n",
    "category = randomChoice(all_categories)\n",
    "line = randomChoice(category_lines[category])\n",
    "line"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([6, 1, 57])"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Scratch check: shape of the one-hot encoding for the sampled name.\n",
    "tensor = torch.zeros(len(line), 1, n_letters)\n",
    "tensor.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
       "          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
       "          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.,\n",
       "          0., 0., 0., 0., 0., 0.]],\n",
       "\n",
       "        [[0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
       "          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
       "          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
       "          0., 0., 0., 0., 0., 0.]],\n",
       "\n",
       "        [[0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
       "          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
       "          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
       "          0., 0., 0., 0., 0., 0.]],\n",
       "\n",
       "        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.,\n",
       "          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
       "          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
       "          0., 0., 0., 0., 0., 0.]],\n",
       "\n",
       "        [[1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
       "          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
       "          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
       "          0., 0., 0., 0., 0., 0.]],\n",
       "\n",
       "        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.,\n",
       "          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
       "          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
       "          0., 0., 0., 0., 0., 0.]]])"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Scratch check: tensor encodings for the sampled (category, name) pair.\n",
    "category_tensor = torch.tensor([all_categories.index(category)], dtype=torch.long)\n",
    "line_tensor = lineToTensor(line)\n",
    "line_tensor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "def timeSince(since):\n",
    "    \"\"\"Elapsed wall-clock time since `since` (a time.time() stamp) as 'Xm Ys'.\"\"\"\n",
    "    elapsed = time.time() - since\n",
    "    minutes, seconds = divmod(elapsed, 60)\n",
    "    return '%dm %ds' % (minutes, seconds)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "class RNN(nn.Module):\n",
    "    \"\"\"Elman RNN cell: h' = tanh(U x + W h); output = log_softmax(V h').\"\"\"\n",
    "\n",
    "    def __init__(self, input_size, hidden_size, output_size):\n",
    "        super(RNN, self).__init__()\n",
    "        self.hidden_size = hidden_size\n",
    "        # Input-to-hidden, hidden-to-hidden and hidden-to-output projections.\n",
    "        self.u = nn.Linear(input_size, hidden_size)\n",
    "        self.w = nn.Linear(hidden_size, hidden_size)\n",
    "        self.v = nn.Linear(hidden_size, output_size)\n",
    "        self.tanh = nn.Tanh()\n",
    "        self.softmax = nn.LogSoftmax(dim=1)\n",
    "\n",
    "    def forward(self, inputs, hidden):\n",
    "        \"\"\"One time step; returns (log-probabilities, new hidden state).\"\"\"\n",
    "        new_hidden = self.tanh(self.w(hidden) + self.u(inputs))\n",
    "        output = self.softmax(self.v(new_hidden))\n",
    "        return output, new_hidden\n",
    "\n",
    "    def initHidden(self):\n",
    "        \"\"\"Zero hidden state for the start of a sequence.\"\"\"\n",
    "        return torch.zeros(1, self.hidden_size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_categories = len(all_categories)  # number of output classes (languages)\n",
    "n_hidden = 128  # hidden state size"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "RNN(\n",
       "  (u): Linear(in_features=57, out_features=128, bias=True)\n",
       "  (w): Linear(in_features=128, out_features=128, bias=True)\n",
       "  (v): Linear(in_features=128, out_features=18, bias=True)\n",
       "  (tanh): Tanh()\n",
       "  (softmax): LogSoftmax(dim=1)\n",
       ")"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Instantiate the model and move its parameters to the selected device.\n",
    "rnn = RNN(n_letters, n_hidden, n_categories)\n",
    "rnn.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "criterion = nn.NLLLoss()  # pairs with the model's LogSoftmax output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "learning_rate = 0.005\n",
    "n_iters = 20\n",
    "def get_lr(iter, learning_rate):\n",
    "    \"\"\"Step schedule: full LR before n_iters iterations, 10% of it afterwards.\"\"\"\n",
    "    if iter < n_iters:\n",
    "        return learning_rate\n",
    "    return learning_rate * 0.1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Running loss accumulator and per-plot-interval loss history.\n",
    "current_loss = 0\n",
    "all_losses = []\n",
    "start = time.time()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluate(line_tensor):\n",
    "    \"\"\"Run the RNN over a one-hot encoded name and return the final output.\n",
    "\n",
    "    Fix: move the input and hidden state to `device`, mirroring train();\n",
    "    the original left both on the CPU, which fails once the model has been\n",
    "    moved to the GPU.\n",
    "    \"\"\"\n",
    "    hidden = rnn.initHidden().to(device)\n",
    "    line_tensor = line_tensor.to(device)\n",
    "\n",
    "    for i in range(line_tensor.size()[0]):\n",
    "        output, hidden = rnn(line_tensor[i], hidden)\n",
    "\n",
    "    return output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
       "          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
       "          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.,\n",
       "          0., 0., 0., 0., 0., 0.]],\n",
       "\n",
       "        [[0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
       "          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
       "          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
       "          0., 0., 0., 0., 0., 0.]],\n",
       "\n",
       "        [[0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
       "          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
       "          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
       "          0., 0., 0., 0., 0., 0.]],\n",
       "\n",
       "        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.,\n",
       "          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
       "          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
       "          0., 0., 0., 0., 0., 0.]],\n",
       "\n",
       "        [[1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
       "          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
       "          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
       "          0., 0., 0., 0., 0., 0.]],\n",
       "\n",
       "        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.,\n",
       "          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
       "          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
       "          0., 0., 0., 0., 0., 0.]]])"
      ]
     },
     "execution_count": 42,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Inspect the last sampled name's one-hot tensor.\n",
    "line_tensor\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "metadata": {},
   "outputs": [],
   "source": [
    "def predict(input_line, n_predictions=3):\n",
    "    \"\"\"Print the model's top-N language guesses (log-prob and name) for a string.\"\"\"\n",
    "    print('\\n> %s' % input_line)\n",
    "    with torch.no_grad():\n",
    "        output = evaluate(lineToTensor(input_line))\n",
    "        # Top N log-probabilities and their category indices.\n",
    "        topv, topi = output.topk(n_predictions, 1, True)\n",
    "        for value, category_index in zip(topv[0].tolist(), topi[0].tolist()):\n",
    "            print('(%.2f) %s' % (value, all_categories[category_index]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 57,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(category_tensor, line_tensor):\n",
    "    \"\"\"One manual-SGD step on a single (label, name) example.\n",
    "\n",
    "    Returns (final output, loss value).  Fixes: rnn.Parameters() raised\n",
    "    AttributeError (the method is parameters()), and add_(scalar, tensor)\n",
    "    is the deprecated overload — use the alpha= keyword instead.\n",
    "    \"\"\"\n",
    "    hidden = rnn.initHidden()\n",
    "\n",
    "    rnn.zero_grad()\n",
    "\n",
    "    line_tensor = line_tensor.to(device)\n",
    "    hidden = hidden.to(device)\n",
    "    category_tensor = category_tensor.to(device)\n",
    "\n",
    "    # Unroll the RNN over the character sequence; keep the final output.\n",
    "    for i in range(line_tensor.size()[0]):\n",
    "        output, hidden = rnn(line_tensor[i], hidden)\n",
    "\n",
    "    loss = criterion(output, category_tensor)\n",
    "    loss.backward()\n",
    "\n",
    "    # Manual SGD update: p <- p - learning_rate * grad.\n",
    "    for p in rnn.parameters():\n",
    "        p.data.add_(p.grad.data, alpha=-learning_rate)\n",
    "\n",
    "    return output, loss.item()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 61,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Iter: 50      time:    0m 0s loss: 2.8364 name:    Maehata  pred:  Russian label: ✗ (Japanese)\n",
      "Iter: 100     time:    0m 0s loss: 3.0079 name:      Haber  pred: Portuguese label: ✗ (German)\n",
      "Iter: 150     time:    0m 0s loss: 3.0448 name:       Baum  pred: Portuguese label: ✗ (German)\n",
      "Iter: 200     time:    0m 0s loss: 2.9760 name:        Xie  pred: Portuguese label: ✗ (Chinese)\n"
     ]
    }
   ],
   "source": [
    "# Main training loop: one-example-at-a-time SGD with periodic logging and\n",
    "# loss-curve sampling (every plot_every iterations).\n",
    "start = time.time()\n",
    "print_every = 50\n",
    "plot_every = 50\n",
    "n_iters = 200\n",
    "for iter in range(1, n_iters + 1):\n",
    "    # Sample a random (language, name) pair with its tensor encodings.\n",
    "    category, line, category_tensor, line_tensor = randomTrainingExample()\n",
    "\n",
    "    # One SGD step; train() returns the final output and the scalar loss.\n",
    "    output, loss = train(category_tensor, line_tensor)\n",
    "\n",
    "    current_loss += loss\n",
    "\n",
    "    # Print iter number, loss, name and guess vs. the true label.\n",
    "    if iter % print_every == 0:\n",
    "        guess, guess_i = categoryFromOutput(output)\n",
    "        correct = '✓' if guess == category else '✗ (%s)' % category\n",
    "        print('Iter: {:<7} time: {:>8s} loss: {:.4f} name: {:>10s}  pred: {:>8s} label: {:>8s}'.format(\n",
    "            iter, timeSince(start), loss, line, guess, correct))\n",
    "\n",
    "    # Record the average loss over the last plot_every iterations.\n",
    "    if iter % plot_every == 0:\n",
    "        all_losses.append(current_loss / plot_every)\n",
    "        current_loss = 0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 62,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAD4CAYAAAD8Zh1EAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAAbyUlEQVR4nO3de3TU9Z3/8ed7JiHhfg3XBAIkwXpFjIiKikCorR50V9t1u/6UVg9eqiJ4zp52/9g97fn909/+AKm2XlbdxfrbqrWtq9RWIzfvYFBAUUkCBAnXECBcE3L5/P7IRJIwIZNkks9cXo9z5mTm+/3MfF+ffMOLTyaTjDnnEBGR+BfwHUBERKJDhS4ikiBU6CIiCUKFLiKSIFToIiIJIsXXgYcNG+ays7N9HV5EJC5t2LDhoHMuI9w+b4WenZ1NUVGRr8OLiMQlM9vZ1j495SIikiBU6CIiCUKFLiKSIFToIiIJQoUuIpIgVOgiIglChS4ikiDirtC/2F3Fr/72NfqzvyIiLcVdoX/6zWGeXLOND0orfUcREYkpcVfo/3B5FqMHprO4cKtW6SIizcRdoaelBHloVi6ffXOE1VsP+I4jIhIz4q7QAW67LJOxQ/qw+O1irdJFREListBTgwEWzMply56jvLVln+84IiIxIS4LHeCWS8cwMaMvSwtLaGjQKl1EJG4LPRgwHpmdx9b9x1jx+V7fcUREvIvbQge48aJRnDeyP48VFlNX3+A7joiIV3Fd6IGAsbAgj+0HT/Daxj2+44iIeBXXhQ4w5/wRXDRmIMtWFlOrVbqIJLG4L3QzY9GcPHYdOsUfisp9xxER8SbuCx1gRl4GU8YO4vFVJVTX1vuOIyLiRUIUupnx6JxJ7K2q5qX13/iOIyLiRUIUOsBVE4cybcIQfrNmG6dOa5UuIsknYQq9aZVecayG331c5juOiEiPS5hCB7g8ewjX5mXw1NrtHK+p8x1HRKRHRVToZlZmZp+b2UYzKwqz38zs12ZWamabzWxK9KNGZlFBHodOnGb5h2W+IoiIeNGRFfr1zrnJzrn8MPu+B+SGLvOBJ6MRrjMmZw1i9neG8/TabVSdqvUVQ0Skx0XrKZebgRdco4+BQWY2KkqP3WELC/I4Wl3Hc+/v8BVBRKTHRVroDnjbzDaY2fww+8cAu5rdLg9t8+KC0QP5/kUjef79HRw+cdpXDBGRHhVpoU93zk2h8amVn5rZtZ05mJnNN7MiMyuqqKjozENE7JHZeZw4XcfT727v1uOIiMSKiArdObc79PEA8Gdgaqshu4GsZrczQ9taP84zzrl851x+RkZG5xJHKG9Ef26+ZDTLPyyj4lhNtx5LRCQWtFvoZtbXzPo3XQfmAF+0GvY6cGfo1S7TgCrnnPc/Ur5gdh6n6xt4cs0231FERLpdJCv0EcD7ZrYJWA/8xTn3NzO7z8zuC415E9gOlAL/ATzQLWk7aPywvtw6ZQwvrtvJvqpq33FERLpVSnsDnHPbgUvCbH+q2XUH/DS60aLjoZm5/Pmz3TyxuoT/fctFvuOIiHSbhPpN0XCyhvThh/lZvPzJLsoPn/QdR0Sk2yR8oQM8ODMHM+PxlaW+o4iIdJukKPRRA3vzT1eM5dVPyyk7eMJ3HBGRbpEUhQ5w/4yJ9AoGWLayxHcUEZFukTSFPrx/OndeNY7XNu6mZP8x33FERKIuaQod4N5rJ9InNchj72iVLiKJJ6kKfUjfXtw9fTx/+XwvX+456juOiEhUJVWhA9x9zQQGpKewpLDYdxQRkahKukIf2DuV+ddO4J2v9rNp1xHfcUREoibpCh1g3tXjGdwnlcVapYtIAknKQu+XlsJ9103k3eIKisoO+Y4jIhIVSVnoAHdemc2wfmksflurdBFJDElb6L17Bfnp9RP5aHslH5Ye9B1HRKTLkrbQAf5x6lhGDUxncWExjX8wUkQkfiV1oaenBnlwZg4bdh5mTXH3viWeiEh3
S+pCB/jBZVlkDu7NUq3SRSTOJX2h90oJsGBWLpvLqyj8cr/vOCIinZb0hQ7wd5eOYcKwviwpLKahQat0EYlPKnQgJRhgwexcvt53jDe/8P7e1iIinaJCD7np4tHkjejH0sJi6rVKF5E4pEIPCQaMhbPz2FZxgv/ZuNt3HBGRDlOhN/PdC0ZywegBLFtZQm19g+84IiIdokJvJhAwFhXksbPyJH/cUO47johIh6jQW5l53nAmZw3i8VWl1NTV+44jIhIxFXorZsajc/LYfeQUr3yyy3ccEZGIqdDDmJ4zjKnZQ3h8VSnVtVqli0h8UKGH0bRKP3Cshhc/3uk7johIRFTobbhiwlCm5wzjyTXbOFFT5zuOiEi7VOjnsGhOHpUnTrP8ozLfUURE2qVCP4cpYwcz87zhPL12O0era33HERE5JxV6OxYV5FF1qpbn39/hO4qIyDmp0Ntx4ZiB3HDBSJ57bwdHTp72HUdEpE0q9AgsLMjj+Ok6nnl3u+8oIiJtUqFHYNLI/tx08Wj+68MyKo/X+I4jIhJWxIVuZkEz+8zMVoTZN8/MKsxsY+hyT3Rj+vfI7Fyqa+t5au0231FERMLqyAp9AfDVOfa/7JybHLo828VcMWdiRj/+7tJMXvhoJ/uPVvuOIyJylogK3cwygRuBhCvqjlgwK5f6BsdvV5f6jiIicpZIV+iPAf8MnOuPhN9qZpvN7FUzywo3wMzmm1mRmRVVVFR0MKp/Y4f24Qf5Wfx+/S52HznlO46ISAvtFrqZ3QQccM5tOMewN4Bs59zFQCGwPNwg59wzzrl851x+RkZGpwL79tDMHACeWFXiOYmISEuRrNCvBuaaWRnwEjDTzF5sPsA5V+mca3r5x7PAZVFNGUNGD+rNj64Yyx+KytlZecJ3HBGRb7Vb6M65nzvnMp1z2cDtwCrn3B3Nx5jZqGY353LuH57GvQdmTCQYMJat1CpdRGJHp1+Hbma/NLO5oZsPm9kWM9sEPAzMi0a4WDV8QDp3XjmO1z7bTemB477jiIgAYM45LwfOz893RUVFXo4dDZXHa7jm/6xm5nnDeeJHU3zHEZEkYWYbnHP54fbpN0U7aWi/NH58dTYrNu/l631HfccREVGhd8X8aybSPz2FpYXFvqOIiKjQu2Jgn1TumT6Bt7bs5/PyKt9xRCTJqdC76CfTsxnUJ5UlhVt9RxGRJKdC76L+6ance+1EVm+tYMPOw77jiEgSU6FHwV1XjWNYv15apYuIVyr0KOjTK4X7Z+TwQWklH22r9B1HRJKUCj1K/umKsYwYkMaSwq34em2/iCQ3FXqUpKcGefD6HD4pO8x7JQd9xxGRJKRCj6IfXp7FmEG9Wfy2Vuki0vNU6FGUlhLk4Vk5bCqvYuVXB3zHEZEko0KPsr+fkkn20D4sKSymoUGrdBHpOSr0KEsNBlgwO5cv9x7lb1v2+Y4jIklEhd4N5l4yhpzh/VhaWEy9Vuki0kNU6N0gGDAWzs6j5MBx3ti0x3ccEUkSKvRu8r0LR3LeyP4sW1lCXf253ltbRCQ6VOjdJBAwFhXksePgCf702W7fcUQkCajQu1HB+SO4OHMgy94p4XSdVuki0r1U6N3IrHGVvvvIKV4p2uU7jogkOBV6N7suL4P8cYN5YlUp1bX1vuOISAJToXczM2PRnDz2Ha3mv9d94zuOiCQwFXoPuGriMK6cMJTfrtnGydN1vuOISIJSofeQR+fkcfB4DS98tNN3FBFJUCr0HpKfPYTr8jJ4eu02jtdolS4i0adC70GPzsnj8Mla/vP9Hb6jiEgCUqH3oIszB1Fw/gieeW87VSdrfccRkQSjQu9hiwryOFZdx7Pvb/cdRUQSjAq9h31n1ABuvHgUz7+/g0MnTvuOIyIJRIXuwcLZuZyqrefptdt8RxGRBKJC9yBneH9umTyG5R+VceBYte84IpIgVOiePDwrl9p6x29Xa5UuItGhQvcke1hfbpuSyX+v+4Y9R075jiMiCUCF
7tFDs3JwOJ5YXeo7iogkgIgL3cyCZvaZma0Isy/NzF42s1IzW2dm2VFNmaAyB/fh9svH8sonu9h16KTvOCIS5zqyQl8AfNXGvruBw865HGAp8KuuBksWD87MIRgwfr2yxHcUEYlzERW6mWUCNwLPtjHkZmB56PqrwCwzs67HS3wjBqRzx7Rx/PHTcrZXHPcdR0TiWKQr9MeAfwbaeh+1McAuAOdcHVAFDG09yMzmm1mRmRVVVFR0PG2Cun/GRNJSgizTKl1EuqDdQjezm4ADzrkNXT2Yc+4Z51y+cy4/IyOjqw+XMIb1S2Pe1dm8vmkPW/cd8x1HROJUJCv0q4G5ZlYGvATMNLMXW43ZDWQBmFkKMBCojGLOhDf/mgn07ZXCY+8U+44iInGq3UJ3zv3cOZfpnMsGbgdWOefuaDXsdeCu0PXbQmNcVJMmuMF9e/GT6eP56xf7+GJ3le84IhKHOv06dDP7pZnNDd18DhhqZqXAIuBn0QiXbO6ePp6BvVNZWqhVuoh0XEpHBjvn1gBrQtf/tdn2auAH0QyWjAb2TmX+tRP497e28tk3h7l07GDfkUQkjug3RWPMvKuyGdK3F0u0SheRDlKhx5i+aSncf91E3is5yPodh3zHEZE4okKPQXdMG0dG/zT+79tb0c+WRSRSKvQY1LtXkAevz2H9jkN8UKpXf4pIZFToMer2qVmMHpjO4kKt0kUkMir0GJWWEuShWbl89s0RVm894DuOiMQBFXoMu+2yTMYO6cOSwmKt0kWkXSr0GJYaDPDwrFy+2H2Ut7bs9x1HRGKcCj3G3TJ5NBMy+rK0sJiGBq3SRaRtKvQYlxIM8MjsPLbuP8aKz/f6jiMiMUyFHgduumgUk0b057HCYurq2/qT9CKS7FTocSAQMBYW5LH94Ale27jHdxwRiVEq9Djx3QtGcOGYASxbWUytVukiEoYKPU6YGY8WTGLXoVP8oajcdxwRiUEq9DgyY1IGl44dxBOrSqipq/cdR0RijAo9jjSt0vdUVfPS+l2+44hIjFGhx5mrc4ZyxfghPLG6lFOntUoXkTNU6HHGzHh0ziQqjtXw4sc7fccRkRiiQo9DU8cP4ZrcYTy5dhvHa+p8xxGRGKFCj1OPzpnEoROnWf5hme8oIhIjVOhxanLWIGadN5yn126j6lSt7zgiEgNU6HFsYUEeR6vreO79Hb6jiEgMUKHHsQvHDOR7F47k+fd3cPjEad9xRMQzFXqcW1iQx4nTdTz97nbfUUTEMxV6nMsb0Z+5l4xm+YdlVByr8R1HRDxSoSeABbNyqamr56m123xHERGPVOgJYEJGP26dksnvPt7Jvqpq33FExBMVeoJ4eFYuDQ2O36wu9R1FRDxRoSeIrCF9+IfLs3jpk28oP3zSdxwR8UCFnkAenJmDmfH4Sq3SRZKRCj2BjBrYmx9NHcurn5ZTdvCE7zgi0sNU6Anmgesnkho0lq0s8R1FRHqYCj3BDO+fzl1XZvPaxt2U7D/mO46I9KB2C93M0s1svZltMrMtZvaLMGPmmVmFmW0MXe7pnrgSiXuvm0if1CCPvaNVukgyiWSFXgPMdM5dAkwGbjCzaWHGveycmxy6PBvNkNIxQ/r24ifTx/OXz/fy5Z6jvuOISA9pt9Bdo+Ohm6mhi+vWVNJl90yfQP/0FJa+U+w7ioj0kIieQzezoJltBA4Ahc65dWGG3Wpmm83sVTPLimZI6biBfVKZf80ECr/cz6ZdR3zHEZEeEFGhO+fqnXOTgUxgqpld2GrIG0C2c+5ioBBYHu5xzGy+mRWZWVFFRUUXYkskfjx9PIP7pLKkUKt0kWTQoVe5OOeOAKuBG1ptr3TONf2pv2eBy9q4/zPOuXznXH5GRkYn4kpH9EtL4d7rJrK2uIKiskO+44hIN4vkVS4ZZjYodL03UAB83WrMqGY35wJfRTGjdMGdV45jWL80Fr+tVbpIootkhT4KWG1mm4FPaHwOfYWZ/dLM5obGPBx6SeMm
4GFgXvfElY7q0yuFB2ZM5KPtlXxYetB3HBHpRuacnxes5Ofnu6KiIi/HTjbVtfXM+Pc1jBncm1fvuxIz8x1JRDrJzDY45/LD7dNviiaB9NQgD87MYcPOw6wt1g+jRRKVCj1J/DA/i8zBvVlSWIyv78pEpHup0JNEr5QAD8/KZXN5FYVf7vcdR0S6gQo9ifz9pWMYP6wvSwqLaWjQKl0k0ajQk0hKMMAjs3P5et8x3vxir+84IhJlKvQkc9PFo8kd3o+lhcXUa5UuklBU6EkmGDAWFuSxreIE/7Nxt+84IhJFKvQkdMMFIzl/1ACWrSyhtr7BdxwRiRIVehIKBIxFBXnsrDzJnz4t9x1HRKJEhZ6kZn1nOJdkDeLXK0upqav3HUdEokCFnqTMjEcL8th95BSvfLLLdxwRiQIVehK7JncYl2cP5vFVpVTXapUuEu9U6EnMzHh0ziQOHKvhxY93+o4jIl2kQk9y0yYM5eqcoTy1dhsnaup8xxGRLlChC4sKJnHw+GmWf1TmO4qIdIEKXbhs3GCun5TBM+9u51h1re84ItJJKnQBGlfpR07W8vz7Zb6jiEgnqdAFgIsyB/LdC0bw7HvbOXLytO84ItIJKnT51sKCPI6fruM/3tvuO4qIdIIKXb513sgB3HjRKP7zgzIqj9f4jiMiHaRClxYemZ1HdW09T63d5juKiHSQCl1ayBnej1suHcMLH+3kwNFq33FEpANU6HKWBbNyqW9w/GZ1qe8oItIBKnQ5y7ihfflBfia/X7+L3UdO+Y4jIhFSoUtYD87MBeCJVSWek4hIpFToEtaYQb35x6lZ/KGonG8qT/qOIyIRUKFLm356fQ7BgLFspVbpIvFAhS5tGj4gnTuvHMefPyun9MBx33FEpB0qdDmn+66bSHpqUKt0kTigQpdzGtovjXlXZfPGpj18ve+o7zgicg4qdGnX/Gsn0D8thaWFxb6jiMg5qNClXYP69OLua8bz1pb9fF5e5TuOiLRBhS4R+cn08Qzqk8qSwq2+o4hIG1ToEpEB6anMv3YCq7dWsGHnYd9xRCSMdgvdzNLNbL2ZbTKzLWb2izBj0szsZTMrNbN1ZpbdLWnFq7uuzGZo315apYvEqEhW6DXATOfcJcBk4AYzm9ZqzN3AYedcDrAU+FVUU0pM6JuWwv0zJvJBaSUfb6/0HUdEWmm30F2jpt8qSQ1dXKthNwPLQ9dfBWaZmUUtpcSMO6aNY8SANJa8XYxzrb8MRMSniJ5DN7OgmW0EDgCFzrl1rYaMAXYBOOfqgCpgaJjHmW9mRWZWVFFR0aXg4kd6apAHr89hfdkh3is56DuOiDQTUaE75+qdc5OBTGCqmV3YmYM5555xzuU75/IzMjI68xASA354eRZjBvVmcaFW6SKxpEOvcnHOHQFWAze02rUbyAIwsxRgIKAnWRNUWkqQh2bmsGnXEVZ+dcB3HBEJieRVLhlmNih0vTdQAHzdatjrwF2h67cBq5yWbgnt1ssyGTe0D0sKi2lo0KkWiQWRrNBHAavNbDPwCY3Poa8ws1+a2dzQmOeAoWZWCiwCftY9cSVWpAYDLJiVy5d7j/LWln2+44gIYL4W0vn5+a6oqMjLsSU66hscc5auJWDG3x65lmBAL2wS6W5mtsE5lx9un35TVDotGDAWFuRRcuA4Kzbv8R1HJOml+A4g8e37F47ivJGlPPZOCTdeNIqUYGKtEZq+g3Wu8Zcvvr397TZH829yW29ruo9rtp8IxrjGQWdtcy22tX3s5mMaf8TR+LEhdJ+mj99ep2lb0/4z93Ut9je7L46GhjP3b37fs7c1e6yzjtnyWA3NPtcNrR4ToKHhzH0bQieioSlP07gWx2x231aPSas5nHXfsz4vLT8/Hfu8nHnM+6+bwA0XjurgV2P7VOjSJYGAsaggj/m/28D1i9eQGgx82ybNi+pMeZ0pkiZtjqH5uDBlFtrfvGRpawwti5kw21qXtcQHMzAgYEYgdCNgYBgB
o9m2xtsW+kiz/dY0DggEQts4s6/pPsaZsc0/ntkfJocFvn1MQvt7pXTPwkeFLl1WcP4I7r1uAuWHTjVuCH1hW+gfBZz5R9d8G83+gbQc17it6TbnGHNmm7XY1/KYZ8ZYs+O2HtN6W9MdjTC5mrY1+4Xo8Lla3q95jpaP1fJ+NBvT5nxCn+im+7fM1fJ+TSVjoZIJV2ati6up/FoW2Jn8zR+z6XjNy7D5fVuUYaB1WZ6jQK1ZKdMyb9N+OUOFLl1mZvz8e9/xHUMk6SXWE54iIklMhS4ikiBU6CIiCUKFLiKSIFToIiIJQoUuIpIgVOgiIglChS4ikiC8/bVFM6sAdnby7sOARHn/M80l9iTKPEBziVVdmcs451zYt3zzVuhdYWZFbf35yHijucSeRJkHaC6xqrvmoqdcREQShApdRCRBxGuhP+M7QBRpLrEnUeYBmkus6pa5xOVz6CIicrZ4XaGLiEgrKnQRkQQR04VuZjeY2VYzKzWzn4XZn2ZmL4f2rzOzbA8xIxLBXOaZWYWZbQxd7vGRsz1m9ryZHTCzL9rYb2b269A8N5vZlJ7OGKkI5jLDzKqanZN/7emMkTCzLDNbbWZfmtkWM1sQZkxcnJcI5xIv5yXdzNab2abQXH4RZkx0O8x9+waosXUBgsA2YALQC9gEnN9qzAPAU6HrtwMv+87dhbnMA57wnTWCuVwLTAG+aGP/94G/0vhuaNOAdb4zd2EuM4AVvnNGMI9RwJTQ9f5AcZivr7g4LxHOJV7OiwH9QtdTgXXAtFZjotphsbxCnwqUOue2O+dOAy8BN7caczOwPHT9VWCWxeabDEYyl7jgnHsXOHSOITcDL7hGHwODzCz6b28eBRHMJS445/Y65z4NXT8GfAWMaTUsLs5LhHOJC6HP9fHQzdTQpfWrUKLaYbFc6GOAXc1ul3P2if12jHOuDqgChvZIuo6JZC4At4a+HX7VzLJ6JlrURTrXeHFl6Fvmv5rZBb7DtCf0LfulNK4Gm4u783KOuUCcnBczC5rZRuAAUOica/O8RKPDYrnQk80bQLZz7mKgkDP/a4s/n9L4dzMuAR4HXvMb59zMrB/wR+AR59xR33m6op25xM15cc7VO+cmA5nAVDO7sDuPF8uFvhtovkrNDG0LO8bMUoCBQGWPpOuYdufinKt0ztWEbj4LXNZD2aItkvMWF5xzR5u+ZXbOvQmkmtkwz7HCMrNUGgvw/znn/hRmSNycl/bmEk/npYlz7giwGrih1a6odlgsF/onQK6ZjTezXjT+wOD1VmNeB+4KXb8NWOVCP12IMe3OpdXzmXNpfO4wHr0O3Bl6VcU0oMo5t9d3qM4ws5FNz2ea2VQa/73E3IIhlPE54Cvn3JI2hsXFeYlkLnF0XjLMbFDoem+gAPi61bCodlhKZ+/Y3ZxzdWb2IPAWja8Sed45t8XMfgkUOedep/HE/87MSmn84dbt/hK3LcK5PGxmc4E6Gucyz1vgczCz39P4KoNhZlYO/BuNP+zBOfcU8CaNr6goBU4CP/aTtH0RzOU24H4zqwNOAbfH6ILhauB/AZ+Hnq8F+BdgLMTdeYlkLvFyXkYBy80sSON/Oq8451Z0Z4fpV/9FRBJELD/lIiIiHaBCFxFJECp0EZEEoUIXEUkQKnQRkQShQhcRSRAqdBGRBPH/AT+tH/z1GnfVAAAAAElFTkSuQmCC",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Loss curve: one point per plot_every training iterations.\n",
    "plt.plot(all_losses)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): redundant — torch and torch.nn are already imported at the top.\n",
    "import torch\n",
    "import torch.nn as nn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [],
   "source": [
    "class LSTM(nn.Module):\n",
    "    \"\"\"From-scratch LSTM over batch-first input of shape (batch, seq, input_size).\n",
    "\n",
    "    forward() returns (hidden_seq, (h_t, c_t)) where hidden_seq is\n",
    "    (batch, seq, hidden_size).\n",
    "    \"\"\"\n",
    "    def __init__(self, input_size, hidden_size):\n",
    "        super().__init__()\n",
    "        self.input_size = input_size\n",
    "        self.hidden_size = hidden_size\n",
    "        # Input gate.\n",
    "        self.U_i = nn.Parameter(torch.Tensor(input_size, hidden_size))\n",
    "        self.V_i = nn.Parameter(torch.Tensor(hidden_size, hidden_size))\n",
    "        self.b_i = nn.Parameter(torch.Tensor(hidden_size))\n",
    "\n",
    "        # Forget gate.  Fix: V_f multiplies the hidden state, so it must be\n",
    "        # (hidden_size, hidden_size) — the original (input_size, hidden_size)\n",
    "        # breaks whenever input_size != hidden_size.\n",
    "        self.U_f = nn.Parameter(torch.Tensor(input_size, hidden_size))\n",
    "        self.V_f = nn.Parameter(torch.Tensor(hidden_size, hidden_size))\n",
    "        self.b_f = nn.Parameter(torch.Tensor(hidden_size))\n",
    "\n",
    "        # Cell candidate.\n",
    "        self.U_c = nn.Parameter(torch.Tensor(input_size, hidden_size))\n",
    "        self.V_c = nn.Parameter(torch.Tensor(hidden_size, hidden_size))\n",
    "        self.b_c = nn.Parameter(torch.Tensor(hidden_size))\n",
    "\n",
    "        # Output gate.\n",
    "        self.U_o = nn.Parameter(torch.Tensor(input_size, hidden_size))\n",
    "        self.V_o = nn.Parameter(torch.Tensor(hidden_size, hidden_size))\n",
    "        self.b_o = nn.Parameter(torch.Tensor(hidden_size))\n",
    "\n",
    "        self.init_weights()\n",
    "\n",
    "    def init_weights(self):\n",
    "        \"\"\"Uniform init in [-1/sqrt(hidden), 1/sqrt(hidden)] for every parameter.\"\"\"\n",
    "        stdv = 1.0 / math.sqrt(self.hidden_size)\n",
    "        for weight in self.parameters():\n",
    "            weight.data.uniform_(-stdv, stdv)\n",
    "\n",
    "    def forward(self, x, init_states=None):\n",
    "        \"\"\"Run the cell over the sequence dimension of batch-first input x.\"\"\"\n",
    "        bs, seq_sz, _ = x.size()\n",
    "        hidden_seq = []\n",
    "\n",
    "        if init_states is None:\n",
    "            h_t = torch.zeros(bs, self.hidden_size).to(x.device)\n",
    "            c_t = torch.zeros(bs, self.hidden_size).to(x.device)\n",
    "        else:\n",
    "            h_t, c_t = init_states\n",
    "\n",
    "        for t in range(seq_sz):\n",
    "            x_t = x[:, t, :]\n",
    "            i_t = torch.sigmoid(x_t @ self.U_i + h_t @ self.V_i + self.b_i)\n",
    "            f_t = torch.sigmoid(x_t @ self.U_f + h_t @ self.V_f + self.b_f)\n",
    "            # Fix: candidate uses matmul (@) with V_c, not elementwise (*).\n",
    "            g_t = torch.tanh(x_t @ self.U_c + h_t @ self.V_c + self.b_c)\n",
    "            o_t = torch.sigmoid(x_t @ self.U_o + h_t @ self.V_o + self.b_o)\n",
    "            c_t = f_t * c_t + i_t * g_t\n",
    "            h_t = o_t * torch.tanh(c_t)\n",
    "            hidden_seq.append(h_t.unsqueeze(0))\n",
    "        hidden_seq = torch.cat(hidden_seq, dim=0)\n",
    "        # Fix: transpose is a tensor method — bare torch.transpose(0, 1) is a\n",
    "        # TypeError.  Reorder (seq, batch, hidden) -> (batch, seq, hidden).\n",
    "        hidden_seq = hidden_seq.transpose(0, 1).contiguous()\n",
    "        return hidden_seq, (h_t, c_t)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dummy batch: (batch=5, seq=3, features=10).\n",
    "input = torch.randn(5,3,10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [],
   "source": [
    "lstm = LSTM(10,512)  # input_size=10, hidden_size=512"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "LSTM()"
      ]
     },
     "execution_count": 41,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# repr shows no children because the weights are raw nn.Parameters, not submodules.\n",
    "lstm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Scratch: a standalone parameter with the same shape as LSTM.U_i.\n",
    "U_i = nn.Parameter(torch.Tensor(10,512))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.utils.data import Dataset,DataLoader\n",
    "from torch.nn.utils.rnn import pack_padded_sequence,pack_sequence,pad_sequence"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Mydata(Dataset):\n",
    "    \"\"\"Minimal Dataset wrapper around an in-memory list of tensors.\"\"\"\n",
    "\n",
    "    def __init__(self, data):\n",
    "        self.data = data\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.data)\n",
    "\n",
    "    def __getitem__(self, index):\n",
    "        return self.data[index]\n",
    "\n",
    "def collate_fn(data):\n",
    "    \"\"\"Sort a batch longest-first and zero-pad it into a rectangular tensor.\"\"\"\n",
    "    data.sort(key=len, reverse=True)\n",
    "    return pad_sequence(data, batch_first=True, padding_value=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Variable-length toy sequences for the padding/packing demos below.\n",
    "a = torch.tensor([1,2,3,4])\n",
    "b = torch.tensor([5,6,7])\n",
    "c = torch.tensor([7,8])\n",
    "d = torch.tensor([9])\n",
    "train_x = [a,b,c,d]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Batches of 2, padded to equal length by collate_fn.\n",
    "data = Mydata(train_x)\n",
    "data_loader = DataLoader(data, batch_size=2, shuffle=True, collate_fn=collate_fn)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fix: Python 3 iterators have no .next() method — use the built-in next().\n",
    "batch_x = next(iter(data_loader))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 57,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[1, 2, 3, 4],\n",
       "        [7, 8, 0, 0]])"
      ]
     },
     "execution_count": 57,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "batch_x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 58,
   "metadata": {},
   "outputs": [],
   "source": [
    "def collate_fn(data):\n",
    "    \"\"\"Pad a batch (longest-first) and pack it for consumption by an RNN.\"\"\"\n",
    "    data.sort(key=len, reverse=True)\n",
    "    lengths = [seq.size(0) for seq in data]\n",
    "    padded = pad_sequence(data, batch_first=True)\n",
    "    return pack_padded_sequence(padded, lengths, batch_first=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 60,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "PackedSequence(data=tensor([7, 9, 8]), batch_sizes=tensor([2, 1]), sorted_indices=None, unsorted_indices=None)"
      ]
     },
     "execution_count": 60,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data = Mydata(train_x)\n",
    "data_loader = DataLoader(data, batch_size=2, shuffle=True, collate_fn=collate_fn)\n",
    "# Fix: .next() does not exist on Python 3 iterators — use next(iter(...)).\n",
    "batch_x = next(iter(data_loader))\n",
    "batch_x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 61,
   "metadata": {},
   "outputs": [],
   "source": [
    "class MyData(Dataset):\n",
    "    \"\"\"Minimal Dataset wrapper around an in-memory list of tensors.\"\"\"\n",
    "\n",
    "    def __init__(self, data):\n",
    "        self.data = data\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.data)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        return self.data[idx]\n",
    "\n",
    "def collate_fn(data):\n",
    "    \"\"\"Sort, pad, add a feature dim, and pack a batch of 1-D int tensors.\"\"\"\n",
    "    data.sort(key=lambda x: len(x), reverse=True)\n",
    "    seq_len = [s.size(0) for s in data]\n",
    "    data = pad_sequence(data, batch_first=True).float()\n",
    "    data = data.unsqueeze(-1)  # (batch, seq) -> (batch, seq, 1)\n",
    "    data = pack_padded_sequence(data, seq_len, batch_first=True)\n",
    "    return data\n",
    "\n",
    "a = torch.tensor([1,2,3,4])\n",
    "b = torch.tensor([5,6,7])\n",
    "c = torch.tensor([7,8])\n",
    "d = torch.tensor([9])\n",
    "train_x = [a, b, c, d]\n",
    "\n",
    "data = MyData(train_x)\n",
    "data_loader = DataLoader(data, batch_size=2, shuffle=True, collate_fn=collate_fn)\n",
    "# Fix: use next(iter(...)) — iterator objects have no .next() in Python 3.\n",
    "batch_x = next(iter(data_loader))\n",
    "\n",
    "# Feed the packed batch through a batch-first LSTM (input_size=1, hidden=4).\n",
    "rnn = nn.LSTM(1, 4, 1, batch_first=True)\n",
    "h0 = torch.rand(1, 2, 4).float()\n",
    "c0 = torch.rand(1, 2, 4).float()\n",
    "out, (h1, c1) = rnn(batch_x, (h0, c0))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6"
  },
  "kernelspec": {
   "display_name": "Python 3.6.9 64-bit",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
